Support offloading encode, for generate() with much less VRAM #269
drdaxxy wants to merge 1 commit into borisdayma:main
Conversation
Sounds awesome!
I don't have time to write a proper example now, sorry... I'm hoping another developer decides to take care of that. |
Could this even get the full one working on a GPU with much less VRAM too? The full mega checkpoint instead of just the fp16?
I guess so.
wait so my RTX 3060 should already be good to go for running this in something like Visions of Chaos? The full checkpoint? |
Those are very interesting ideas @drdaxxy! I'm gonna try to think about how to integrate it in a clean way.
generate() from Transformers can take precomputed encoder outputs as kwargs instead of running the encoder itself. This PR extends that to "super conditioning" sampling. It also allows providing only one "null sequence" per batch, as input tokens or encoder state, since that prompt is normally constant.
How is this useful? The encoder only needs to run once per distinct prompt, which even on a household CPU takes 1-2 seconds per input (worst case: no batching, no reuse). With that step offloaded, generate() works without 2 or 4 gigabytes of encoder weights (mega-1-fp16 and mega-1, respectively) hogging VRAM.
That way, mega-1-fp16 can run on a 4 GB GPU (batch size 1, without VQGAN, which is fast enough on CPU) and full-precision mega-1 can run on an 8 GB GPU (batch size 1 with VQGAN, up to batch size 3 without).
Specifically, without VQGAN, batch size 1 needs 3728 MiB in float16 and 6770 MiB in float32 this way. GPU-accelerating the VQGAN adds 770 MiB, assuming we also `del vqgan_params["encoder"]` (we never need those weights for generating images) before `replicate(vqgan_params)` or the like. On systems that have enough memory anyway, up to 10 (fp32) or 20 (fp16) more items fit in a batch. Given the CPU encode cost, that's a few percent slower or faster in my experience (especially combined with other tricks in #247), depending on how much state is shared.
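The VQGAN trimming step can be sketched like this (the dict layout below is illustrative; the real pytree comes from the Flax VQGAN checkpoint):

```python
# Hedged sketch: drop the VQGAN *encoder* weights before device
# replication, since image generation only ever decodes latent codes
# back to pixels. The param structure here is made up for illustration.

vqgan_params = {
    "encoder": {"conv_in": [0.1, 0.2]},   # only needed to encode images
    "decoder": {"conv_out": [0.3, 0.4]},  # turns codes back into pixels
    "quantize": {"embedding": [0.5]},
}

# Free the encoder weights before replicate(vqgan_params) or the like,
# so they never get copied to each device:
del vqgan_params["encoder"]
```

Only the decoder and quantizer survive, which is all the generation path touches.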